# ------------------------------------------------------------------------------
# Gets train set predicted values using ensemble model for each fold
# Author: Cassidy Shubatt <cshubatt@gmail.com>
# To run: bash 07_predict_ensemble.sh {split} {restriction}
# ------------------------------------------------------------------------------

# Seeding ----------------------------------------------------------------------
# Fix the RNG state so any random draws made later in the run are reproducible.
set.seed(1)

# Libraries --------------------------------------------------------------------
message("Loading libraries...")
library(yaml)
library(data.table)
library(tidyverse)
library(glue)
library(Matrix)
library(glmnet)
library(optparse)
library(here)
library(testit) # assert()

u <- modules::use(here("lib", "util.R"))

# Command Line Args ------------------------------------------------------------
# Both options are required: they select the sub-directories that are read
# from and written to further down.
arg_config <- list(
  make_option("--split", type = "character"),
  make_option("--restriction", type = "character")
)
arg_parser <- OptionParser(option_list = arg_config)
arg_list <- parse_args(arg_parser)

# An option that was not supplied is NULL here, and file.path() with a NULL
# component returns character(0). Fail now with a message that names the
# actual problem instead of a confusing directory error later on.
assert(
  "Both --split and --restriction are supplied",
  !is.null(arg_list$split),
  !is.null(arg_list$restriction)
)

# Helpers ----------------------------------------------------------------------
#' Build the prediction column name for a saved model file
#'
#' Removes the trailing fold suffix (`__<fold>.rds`, folds 1-5) from the file
#' name, prefixes `p__`, and appends `__<restriction>` unless the run's
#' `--restriction` is "all". Reads the restriction from the script-level
#' `arg_list`.
#'
#' @param model_file File name of a saved model ending in `__<fold>.rds`.
#' @return A glue string: `p__<name>` or `p__<name>__<restriction>`.
get_score_name <- function(model_file) {
  # Escape the dot and anchor at the end of the name so that only a genuine
  # fold suffix is stripped, never a look-alike earlier in the file name.
  clean <- str_remove(model_file, "__[1-5]\\.rds$")
  # The restriction is a single value, so use scalar if/else rather than the
  # vectorised ifelse().
  restriction_lab <- if (identical(arg_list$restriction, "all")) {
    ""
  } else {
    glue("__{arg_list$restriction}")
  }
  glue("p__{clean}{restriction_lab}")
}

# Directories ------------------------------------------------------------------
message("Establishing directories...")
paths <- read_yaml(here("lib", "filepaths.yml"))

# The three working directories share one layout:
#   <oos root>/<kind>/<split>/<restriction>
oos_subdir <- function(kind) {
  file.path(paths$modeling$oos, kind, arg_list$split, arg_list$restriction)
}
subscores_dir <- oos_subdir("subscores")
models_dir <- oos_subdir("models")
prediction_dir <- oos_subdir("prediction")

# Only the output directory is verified up front.
assert("Predictions directory exists", dir.exists(prediction_dir))

# Locate Models ----------------------------------------------------------------
message("Locating ensemble models...")
ensemble_files <- str_subset(list.files(models_dir), "ensemble__")

# Output columns: the two identifier columns followed by one prediction column
# per distinct score name (union() drops the repeats that arise because the
# per-fold files of one model map to the same name).
score_colnames <- union(
  c("ed_enc_id", "train_fold"),
  unlist(lapply(ensemble_files, get_score_name))
)

# Zero-row frame with the final column layout; fold results are appended to it.
scores <- setNames(
  data.frame(matrix(nrow = 0, ncol = length(score_colnames))),
  score_colnames
)

# Prediction -------------------------------------------------------------------
# For each of the five folds: load that fold's subscores, keep the rows whose
# train_fold equals the fold, add one prediction column per ensemble model
# saved for the fold, and append the selected columns to `scores`.
for (i in 1:5) {
  message("Predicting ensemble for fold ", i, "...")
  subscores <- readRDS(file.path(subscores_dir, glue("subscores__{i}.rds")))
  fold_scores <- subscores %>%
    filter(train_fold == i) %>%
    setDT()

  # Match the fold suffix at the very end of the file name. A bare "__<i>"
  # would also select files from other folds whose model name happens to
  # contain "__<i>".
  fold_pattern <- paste0("__", i, "\\.rds$")
  fold_ensemble_files <- str_subset(ensemble_files, fold_pattern)

  # `:=` adds each prediction column to fold_scores by reference.
  for (model_file in fold_ensemble_files) {
    model <- readRDS(file.path(models_dir, model_file))
    score_name <- get_score_name(model_file)
    fold_scores[, (score_name) := predict(model, fold_scores, type = "response")]
  }

  # all_of() errors if this fold lacks any of the expected prediction columns.
  fold_scores <- fold_scores %>%
    select(all_of(score_colnames))
  scores <- rbind(scores, fold_scores)
}

# Save -------------------------------------------------------------------------
message("Saving...")
# Combined train-set predictions for all folds go into the prediction directory.
scores_path <- file.path(prediction_dir, "scores_train_set.rds")
write_rds(scores, scores_path)

message("Done.")
